# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

from functools import partial
from itertools import product
from unittest.mock import Mock
from prompt_toolkit.document import Document

from pgsqltoolsservice.language.completion.packages.parseutils.meta import FunctionMetadata, ForeignKey
from pgsqltoolsservice.language.completion import PGCompleter
from pgsqltoolsservice.language.completion.pg_completion import PGCompletion

# Pairs of ``qualify_columns`` values (the same option names enumerated in
# MetaData.get_completers). Presumably passed as the ``qualify`` list of the
# completer factory: ``qual`` covers configurations that qualify columns,
# ``no_qual`` those that do not -- confirm against the calling tests.
qual = ['if_more_than_one_table', 'always']
no_qual = ['if_more_than_one_table', 'never']


def escape(name):
    """Quote ``name`` the way expected completions show it.

    ``name`` is returned unchanged when ``str.islower()`` is true for it and it
    is neither ``'select'`` nor ``'localtimestamp'``. Otherwise it is wrapped
    in double quotes; embedded double quotes are not doubled.
    """
    if name.islower() and name not in ('select', 'localtimestamp'):
        return name
    return '"' + name + '"'


def completion(display_meta, text, pos=0):
    """Build a ``PGCompletion`` for ``text`` tagged with ``display_meta``.

    ``pos`` is forwarded as the completion's ``start_position``.
    """
    return PGCompletion(
        text,
        display_meta=display_meta,
        start_position=pos,
    )


def function(text, pos=0, display=None):
    """Build a ``PGCompletion`` with ``display_meta='function'``.

    ``display`` falls back to ``text`` when it is omitted or falsy; ``pos`` is
    forwarded as ``start_position``.
    """
    shown = display if display else text
    return PGCompletion(
        text, start_position=pos, display_meta='function', display=shown)


def get_result(completer, text, position=None):
    """Return ``completer``'s completions for ``text``.

    The cursor sits at ``position``, or at the end of ``text`` when
    ``position`` is None. A ``Mock`` stands in for the second argument of
    ``get_completions``.
    """
    if position is None:
        position = len(text)
    document = Document(text=text, cursor_position=position)
    return completer.get_completions(document, Mock())


def result_set(completer, text, position=None):
    """Return the completions from ``get_result`` collected into a ``set``."""
    completions = get_result(completer, text, position)
    return set(completions)


# Shorthand constructors for each kind of expected completion. Each one is a
# ``functools.partial`` of ``completion`` with ``display_meta`` fixed, so the
# code below is equivalent to
#     def schema(text, pos=0):
#         return completion('schema', text, pos)
# and so on for the other kinds.
schema = partial(completion, 'schema')
table = partial(completion, 'table')
view = partial(completion, 'view')
column = partial(completion, 'column')
keyword = partial(completion, 'keyword')
datatype = partial(completion, 'datatype')
alias = partial(completion, 'table alias')
name_join = partial(completion, 'name join')
fk_join = partial(completion, 'fk join')
join = partial(completion, 'join')


def wildcard_expansion(cols, pos=-1):
    """Build the completion that expands ``*`` into the column text ``cols``.

    It is displayed as ``*`` with meta ``'columns'``; ``pos`` is forwarded as
    ``start_position``.
    """
    return PGCompletion(
        cols,
        display='*',
        display_meta='columns',
        start_position=pos,
    )


class MetaData(object):
    """Builds expected completions and configured completers from a metadata dict.

    The categories of ``metadata`` read by this class are:

    * ``'tables'`` / ``'views'``: ``{schema: {relation: [column, ...]}}``
    * ``'functions'``: ``{schema: [func_meta, ...]}``. Item 0 of each
      ``func_meta`` is the function name, item 1 the argument names and item 3
      the argument modes. The whole sequence is also unpacked into
      ``FunctionMetadata`` after the schema name.
    * ``'datatypes'``: ``{schema: [type_name, ...]}``
    * ``'foreignkeys'``: ``{key: [fk_args, ...]}``; each ``fk_args`` is
      unpacked into ``ForeignKey`` (the keys themselves are ignored there).
    * ``'defaults'`` (optional): ``{schema: {(table, column): default}}``
    """

    def __init__(self, metadata):
        self.metadata = metadata

    def builtin_functions(self, pos=0):
        """Return function completions for every name in the completer's ``functions``."""
        return [function(f, pos) for f in self.completer.functions]

    def builtin_datatypes(self, pos=0):
        """Return datatype completions for every name in the completer's ``datatypes``."""
        return [datatype(dt, pos) for dt in self.completer.datatypes]

    def keywords(self, pos=0):
        """Return keyword completions for every key of the completer's ``keywords_tree``."""
        return [keyword(kw, pos) for kw in self.completer.keywords_tree.keys()]

    def columns(self, tbl, parent='public', typ='tables', pos=0):
        """Return column completions for ``tbl`` in schema ``parent``.

        ``typ`` selects the metadata category. For ``'functions'`` the columns
        are item 1 of the first function entry named ``tbl`` (IndexError if
        there is none). For any other category they are
        ``metadata[typ][parent][tbl]``.
        """
        if typ == 'functions':
            fun = [x for x in self.metadata[typ][parent] if x[0] == tbl][0]
            cols = fun[1]
        else:
            cols = self.metadata[typ][parent][tbl]
        return [column(escape(col), pos) for col in cols]

    def datatypes(self, parent='public', pos=0):
        """Return datatype completions for schema ``parent`` (empty if absent)."""
        return [
            datatype(escape(x), pos)
            for x in self.metadata.get('datatypes', {}).get(parent, [])]

    def tables(self, parent='public', pos=0):
        """Return table completions for schema ``parent`` (empty if absent)."""
        return [
            table(escape(x), pos)
            for x in self.metadata.get('tables', {}).get(parent, [])]

    def views(self, parent='public', pos=0):
        """Return view completions for schema ``parent`` (empty if absent)."""
        return [
            view(escape(x), pos)
            for x in self.metadata.get('views', {}).get(parent, [])]

    def functions(self, parent='public', pos=0):
        """Return function completions for schema ``parent`` (empty if absent).

        Each completion's text is ``name(arg := , ...)`` and its display is
        ``name(arg, ...)``. Both list only the arguments whose mode is ``'i'``
        (IN) or ``'b'`` (INOUT), using pg_proc.proargmodes codes.
        """
        completions = []
        for func_meta in self.metadata.get('functions', {}).get(parent, []):
            name = escape(func_meta[0])
            arg_names, arg_modes = func_meta[1], func_meta[3]
            # Computed once and shared by the insert text and the display.
            in_args = [
                arg_name
                for arg_name, arg_mode in zip(arg_names, arg_modes)
                if arg_mode in ('b', 'i')
            ]
            text = name + '(' + ', '.join(a + ' := ' for a in in_args) + ')'
            display = name + '(' + ', '.join(in_args) + ')'
            completions.append(function(text, pos, display))
        return completions

    def schemas(self, pos=0):
        """Return schema completions for every schema named in the metadata.

        Every key of every top-level category dict counts as a schema name.
        The result order follows set iteration order.
        """
        names = set(sch for schs in self.metadata.values() for sch in schs)
        return [schema(escape(s), pos=pos) for s in names]

    def functions_and_keywords(self, parent='public', pos=0):
        """Return schema functions, then the completer's functions, then keywords."""
        return (
            self.functions(parent, pos) + self.builtin_functions(pos) +
            self.keywords(pos)
        )

    # Note that the filtering parameters here only apply to the columns
    def columns_functions_and_keywords(
            self, tbl, parent='public', typ='tables', pos=0
    ):
        """Return functions/keywords for 'public' plus the columns of ``tbl``."""
        return (
            self.functions_and_keywords(pos=pos) +
            self.columns(tbl, parent, typ, pos)
        )

    def from_clause_items(self, parent='public', pos=0):
        """Return function, view and table completions for schema ``parent``."""
        return (
            self.functions(parent, pos) + self.views(parent, pos) +
            self.tables(parent, pos)
        )

    def schemas_and_from_clause_items(self, parent='public', pos=0):
        """Return ``from_clause_items`` followed by all schema completions."""
        return self.from_clause_items(parent, pos) + self.schemas(pos)

    def types(self, parent='public', pos=0):
        """Return datatype and table completions for schema ``parent``."""
        return self.datatypes(parent, pos) + self.tables(parent, pos)

    @property
    def completer(self):
        """A completer with default settings; a new one is built on every access."""
        return self.get_completer()

    def get_completers(self, casing):
        """Return a factory that builds completers over a grid of configurations.

        :param casing: list of casing words passed to ``get_completer`` for
            every configuration whose casing flag is ``True``.
        :returns: a function ``completers(casing=None, filtr=None,
            aliasing=None, qualify=None)`` returning a list of completers.
            ``casing``, ``filtr`` and ``aliasing`` are bools and ``qualify``
            is a list of ``qualify_columns`` values. ``None`` (or an empty
            ``qualify``) means every allowed value, so ``completers()`` yields
            all 24 combinations (2 x 2 x 2 x 3). By contrast,
            ``completers(True, False, True, ['never'])`` yields the single
            completer with casing, without ``search_path`` filtering of
            objects, with table aliasing and without column qualification.
        """
        def _cfg(use_casing, filtr, aliasing, qualify):
            # Keyword arguments for a single get_completer() call.
            cfg = {'settings': {}}
            if use_casing:
                cfg['casing'] = casing
            cfg['settings']['search_path_filter'] = filtr
            cfg['settings']['generate_aliases'] = aliasing
            cfg['settings']['qualify_columns'] = qualify
            return cfg

        def _cfgs(use_casing, filtr, aliasing, qualify):
            # Expand each unconstrained (None) option into all of its values.
            casings = [True, False] if use_casing is None else [use_casing]
            filtrs = [True, False] if filtr is None else [filtr]
            aliases = [True, False] if aliasing is None else [aliasing]
            qualifys = qualify or ['always', 'if_more_than_one_table', 'never']
            return [
                _cfg(*p) for p in product(casings, filtrs, aliases, qualifys)
            ]

        def completers(casing=None, filtr=None, aliasing=None, qualify=None):
            return [
                self.get_completer(**cfg)
                for cfg in _cfgs(casing, filtr, aliasing, qualify)
            ]

        return completers

    def _make_col(self, sch, tbl, col):
        """Return the column tuple passed to ``extend_columns``.

        Every column is typed ``'text'``. The last two items say whether
        ``metadata['defaults'][sch]`` has an entry for ``(tbl, col)`` and what
        that entry is (``None`` if absent).
        """
        defaults = self.metadata.get('defaults', {}).get(sch, {})
        return (sch, tbl, col, 'text', (tbl, col) in defaults, defaults.get((tbl, col)))

    def get_completer(self, settings=None, casing=None):
        """Build a ``PGCompleter`` loaded with everything in ``self.metadata``.

        :param settings: settings dict forwarded to ``PGCompleter``.
        :param casing: optional list passed to ``extend_casing``.
        :returns: the populated completer, with search path ``['public']``.

        A missing metadata category is treated as empty, consistent with the
        ``.get(..., {})`` lookups in the expectation helpers of this class.
        """
        metadata = self.metadata
        comp = PGCompleter(smart_completion=True, settings=settings)

        schemata, tables, tbl_cols, views, view_cols = [], [], [], [], []

        # Schemas are registered only from the 'tables' category.
        for sch, tbls in metadata.get('tables', {}).items():
            schemata.append(sch)

            for tbl, cols in tbls.items():
                tables.append((sch, tbl))
                # Let all columns be text columns
                tbl_cols.extend([self._make_col(sch, tbl, col)
                                 for col in cols])

        for sch, tbls in metadata.get('views', {}).items():
            for tbl, cols in tbls.items():
                views.append((sch, tbl))
                # Let all columns be text columns
                view_cols.extend([self._make_col(sch, tbl, col)
                                  for col in cols])

        functions = [
            FunctionMetadata(sch, *func_meta, arg_defaults=None)
            for sch, funcs in metadata.get('functions', {}).items()
            for func_meta in funcs]

        datatypes = [
            (sch, typ)
            for sch, typs in metadata.get('datatypes', {}).items()
            for typ in typs]

        foreignkeys = [
            ForeignKey(*fk) for fks in metadata.get('foreignkeys', {}).values()
            for fk in fks]

        comp.extend_schemata(schemata)
        comp.extend_relations(tables, kind='tables')
        comp.extend_relations(views, kind='views')
        comp.extend_columns(tbl_cols, kind='tables')
        comp.extend_columns(view_cols, kind='views')
        comp.extend_functions(functions)
        comp.extend_datatypes(datatypes)
        comp.extend_foreignkeys(foreignkeys)
        comp.set_search_path(['public'])
        comp.extend_casing(casing or [])

        return comp
